Program Listing for File lanenet_node.py

Return to documentation for file (codes/lanekerbnetros/lanenet_node.py)

#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# @Time    : 17-05-2019
# @Author  : Zhou Hui
# @Original site    : https://github.com/MaybeShewill-CV/lanenet-lane-detection
# @File    : lanenet_node.py

"""
A ros node for lane and kerb detection
"""

import time
import math
import tensorflow as tf
import numpy as np
import cv2
import os

from lanenet_model import lanenet_merge_model
from lanenet_model import lanenet_cluster
from lanenet_model import lanenet_postprocess
from config import global_config

import rospy
from sensor_msgs.msg import Image
from cv_bridge import CvBridge, CvBridgeError


# Project-wide configuration object; init_lanenet reads its GPU memory
# settings (CFG.TEST.GPU_MEMORY_FRACTION, CFG.TRAIN.TF_ALLOW_GROWTH).
CFG = global_config.cfg
# Per-channel means subtracted from every resized frame before inference.
# Frames are decoded by cv_bridge as "bgr8", so the entries are in B, G, R order.
VGG_MEAN = [
    103.939,  # blue
    116.779,  # green
    123.68,   # red
]


class lanenet_detector(object):
    """ROS node wrapper that runs LaneNet lane/curb segmentation on images.

    Configuration comes from private ROS parameters (``~image_topic``,
    ``~output_image``, ``~output_lane``, ``~weight_path``, ``~use_gpu``,
    ``~save_dir``). On construction the TensorFlow graph is built and the
    weights restored, then the node subscribes to ``~image_topic`` and
    republishes the rendered lane/curb image on ``~output_image``.

    Every processed frame is also shown in an OpenCV window, with the resized
    input on the left and the detection on the right. Keys in that window:
    's' toggles saving frames to ``~save_dir``, space toggles pause, and ESC
    shuts the node down.
    """

    _WINDOW_NAME = "left: input image, right: lane and curb detction; Press 's' to save, ' ' to pause, ESC' to quit"

    def __init__(self):
        self.image_topic = rospy.get_param('~image_topic')
        self.output_image = rospy.get_param('~output_image')
        # Read for completeness; not used by this class.
        self.output_lane = rospy.get_param('~output_lane')
        self.weight_path = rospy.get_param('~weight_path')
        self.use_gpu = rospy.get_param('~use_gpu')
        self.save_dir = rospy.get_param('~save_dir')
        self.framecount = 0
        # Show every frame_show_ratio-th frame; while saving is enabled,
        # write every frame_save_ratio-th shown frame to save_dir.
        self.frame_show_ratio = 1
        self.frame_save_ratio = 5
        # Side-by-side canvas: resized input (left) and detection (right).
        self.frame_2 = np.zeros((256, 512*2, 3), np.uint8)
        # Delay for cv2.waitKey: 1 = running, 0 = paused (block until a key).
        self.nWaitTime = 1
        # 1 while frame saving is enabled (toggled with 's'), 0 otherwise.
        self.s_nWaitTime = 0
        self.savecount = 0

        self.init_lanenet()
        self.bridge = CvBridge()
        # Create the publisher before subscribing. rospy may deliver a message
        # as soon as the subscriber is registered, and img_callback needs
        # self.pub_image to exist by then.
        self.pub_image = rospy.Publisher(self.output_image, Image, queue_size=1)
        self.sub_image = rospy.Subscriber(self.image_topic, Image, self.img_callback, queue_size=1)

    def init_lanenet(self):
        """Build the LaneNet inference graph and restore weights from ``self.weight_path``.

        Sets ``self.input_tensor`` (a float32 placeholder of shape
        [1, 256, 512, 3]) and ``self.binary_seg_ret`` /
        ``self.instance_seg_ret`` (the network outputs). Also sets
        ``self.cluster`` and ``self.postprocessor`` (the mask helpers) and
        ``self.sess`` (the session holding the restored variables).
        """
        self.input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')
        phase_tensor = tf.constant('test', tf.string)
        net = lanenet_merge_model.LaneNet(phase=phase_tensor, net_flag='vgg')
        self.binary_seg_ret, self.instance_seg_ret = net.inference(input_tensor=self.input_tensor, name='lanenet_model')

        self.cluster = lanenet_cluster.LaneNetCluster()
        self.postprocessor = lanenet_postprocess.LaneNetPoseProcessor()

        saver = tf.train.Saver()
        # Session configuration. device_count caps how many devices of each
        # type TensorFlow may use. Hiding the GPUs ({'GPU': 0}) is what forces
        # CPU-only execution; {'CPU': 0} would leave the GPU in use.
        if self.use_gpu:
            sess_config = tf.ConfigProto(device_count={'GPU': 1})
        else:
            sess_config = tf.ConfigProto(device_count={'GPU': 0})
        sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
        sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
        sess_config.gpu_options.allocator_type = 'BFC'

        self.sess = tf.Session(config=sess_config)
        saver.restore(sess=self.sess, save_path=self.weight_path)

    def img_callback(self, data):
        """Process one incoming ``sensor_msgs/Image`` message.

        Steps:
        1. Decode the message as BGR and resize it to 512x256.
        2. Run inference.
        3. Publish the rendered result on ``self.output_image``, carrying the
           input header.
        4. Update the preview window and, when enabled, the saved frames.

        :param data: the received ``sensor_msgs.msg.Image``.
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
            # Nothing to process without a decoded image; skip this frame
            # instead of falling through with an undefined cv_image.
            rospy.logerr('cv_bridge conversion failed: %s', e)
            return

        src_image = cv2.resize(cv_image, (512, 256), interpolation=cv2.INTER_LINEAR)
        # Keep an untouched copy of the resized input for display and saving,
        # independent of whatever the mask helpers do with src_image.
        copy_image = src_image.copy()
        # Mean subtraction produces a new float array; src_image is unchanged.
        image = src_image - VGG_MEAN
        mask_image = self.inference_net(image, src_image)

        out_img_msg = self.bridge.cv2_to_imgmsg(mask_image, "bgr8")
        # Carry over the input stamp and frame_id so consumers can associate
        # the result with the frame it was computed from.
        out_img_msg.header = data.header
        self.pub_image.publish(out_img_msg)

        height, width = image.shape[:2]
        self.frame_2[:height, :width, :] = copy_image
        self.frame_2[:height, width:width * 2, :] = mask_image
        if self.framecount % self.frame_show_ratio == 0:
            cv2.imshow(self._WINDOW_NAME, self.frame_2)
            key = cv2.waitKey(self.nWaitTime)
            if self.s_nWaitTime and self.framecount % self.frame_save_ratio == 0:
                self._save_frames(copy_image, mask_image)
            self._handle_key(key)
        print('framecount: {:0>6d}'.format(self.framecount))
        self.framecount = self.framecount + 1

    def _save_frames(self, input_image, mask_image):
        """Write the input, the detection and the side-by-side canvas to ``self.save_dir``."""
        outputs = (
            ("image", input_image),
            ("curb", mask_image),
            ("result", self.frame_2),
        )
        for prefix, img in outputs:
            path = os.path.join(self.save_dir, "%s_%06i.jpg" % (prefix, self.savecount))
            # cv2.imwrite reports failure (e.g. missing directory) only via its
            # return value, so surface it instead of silently dropping frames.
            if not cv2.imwrite(path, img):
                rospy.logwarn('Failed to write %s', path)
        print('savecount: {:0>6d}'.format(self.savecount))
        self.savecount = self.savecount + 1

    def _handle_key(self, key):
        """React to a key pressed in the preview window: ESC, 's' or space."""
        if key == 27:  # ESC: shut the node down
            if not rospy.is_shutdown():
                rospy.signal_shutdown('Quit')
        elif key == ord('s'):  # toggle frame saving
            self.s_nWaitTime = 0 if self.s_nWaitTime else 1
            if self.s_nWaitTime:
                print('Start saving images..., frame_save_ratio = {}'.format(self.frame_save_ratio))
            else:
                print('End...')
        elif key == 32:  # space bar: toggle pause
            # Keep an int. cv2.waitKey(0) blocks until the next key (paused)
            # and cv2.waitKey(1) polls briefly (running). Newer OpenCV Python
            # bindings reject a bool for an int argument.
            self.nWaitTime = 0 if self.nWaitTime else 1
            if self.nWaitTime:
                print('Unpaused...')
            else:
                print('Paused...')

    def preprocessing(self, img):
        """Resize ``img`` to 512x256 and subtract the per-channel ``VGG_MEAN``.

        :param img: an image with 3 channels in the same order as VGG_MEAN (B, G, R).
        :return: the mean-subtracted float array, ready to feed to the network.
        """
        image = cv2.resize(img, (512, 256), interpolation=cv2.INTER_LINEAR)
        image = image - VGG_MEAN
        return image

    def inference_net(self, img, original_img):
        """Run LaneNet on one preprocessed frame and render the lane and curb masks.

        :param img: the mean-subtracted 256x512x3 frame fed to ``self.input_tensor``.
        :param original_img: the resized frame before mean subtraction, passed
            as the source image to ``LaneNetCluster.get_lane_mask``.
        :return: the image produced by ``LaneNetCluster.get_curb_mask``, which
            receives the lane-mask image as its source image.
        """
        binary_seg_image, instance_seg_image = self.sess.run([self.binary_seg_ret, self.instance_seg_ret],
                                                             feed_dict={self.input_tensor: [img]})
        binary_seg_image[0] = self.postprocessor.postprocess(binary_seg_image[0])
        mask_image = self.cluster.get_lane_mask(binary_seg_ret=binary_seg_image[0],
                                                instance_seg_ret=instance_seg_image[0],
                                                source_image=original_img)
        mask_image_roadf = self.cluster.get_curb_mask(binary_seg_ret=binary_seg_image[0],
                                                      instance_seg_ret=instance_seg_image[0],
                                                      source_image=mask_image)
        return mask_image_roadf


if __name__ == '__main__':
    # Register the node first: the detector reads its '~' private params
    # during construction.
    rospy.init_node('lanenet_node')
    # Hold an explicit reference for the node's lifetime; all work happens in
    # the detector's subscriber callback while spin() blocks.
    detector = lanenet_detector()
    rospy.spin()